{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c71ca3ef",
   "metadata": {},
   "source": [
    "# Basic training and fine-tuning with RAGatouille\n",
    "\n",
    "In this quick example, we'll use the `RAGTrainer` magic class to show how easily you can fine-tune an existing ColBERT model, or train one from any BERT/RoBERTa-like model (for example, to [train one for a previously unsupported language like Japanese](https://huggingface.co/bclavie/jacolbert))."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6f69fca",
   "metadata": {},
   "source": [
    "First, we'll create an instance of `RAGTrainer`. We need to give it two arguments: the `model_name` we want to give to the model we're training, and the `pretrained_model_name` of the base model, which can be either a local path or the name of a model on the Hugging Face Hub. If you're training for a language other than English, you should also pass a two-letter `language_code` argument (PLACEHOLDER ISO) so we can get the relevant processing utils!\n",
    "\n",
    "The trainer will auto-detect whether it's an existing ColBERT model or a BERT base model, and will set itself up accordingly!\n",
    "\n",
    "Please note: training can currently only be run on GPU, and will error out if using CPU/MPS! Training is also currently not functional on Google Colab and Windows 10.\n",
    "\n",
    "Whether we're training from scratch or fine-tuning doesn't matter: all the steps are the same. For this example, let's fine-tune ColBERTv2:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1e81ac64-d222-412c-925c-2a5262266c0e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from ragatouille import RAGTrainer\n",
    "trainer = RAGTrainer(model_name=\"GhibliColBERT\", pretrained_model_name=\"colbert-ir/colbertv2.0\", language_code=\"en\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fa8016a",
   "metadata": {},
   "source": [
    "To train retrieval models like ColBERT, we need training triplets: queries, positive passages, and negative passages for each query.\n",
    "\n",
    "In the next tutorial, we'll see [how to generate synthetic queries when we don't have any annotated passages]. For this tutorial, we'll assume that we have queries and relevant passages, but that we're lacking negative ones (because that's not information we gather from our users).\n",
    "\n",
    "Let's assume our corpus is the same as the one we [used for our example about indexing and searching](PLACEHOLDER): Hayao Miyazaki's Wikipedia page. Let's first fetch the content from Wikipedia:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ca72a0d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "def get_wikipedia_page(title: str):\n",
    "    \"\"\"\n",
    "    Retrieve the full text content of a Wikipedia page.\n",
    "    \n",
    "    :param title: str - Title of the Wikipedia page.\n",
    "    :return: str - Full text content of the page as raw string.\n",
    "    \"\"\"\n",
    "    # Wikipedia API endpoint\n",
    "    URL = \"https://en.wikipedia.org/w/api.php\"\n",
    "\n",
    "    # Parameters for the API request\n",
    "    params = {\n",
    "        \"action\": \"query\",\n",
    "        \"format\": \"json\",\n",
    "        \"titles\": title,\n",
    "        \"prop\": \"extracts\",\n",
    "        \"explaintext\": True,\n",
    "    }\n",
    "\n",
    "    # Custom User-Agent header to comply with Wikipedia's best practices\n",
    "    headers = {\n",
    "        \"User-Agent\": \"RAGatouille_tutorial/0.0.1 (ben@clavie.eu)\"\n",
    "    }\n",
    "\n",
    "    response = requests.get(URL, params=params, headers=headers)\n",
    "    data = response.json()\n",
    "\n",
    "    # Extracting page content\n",
    "    page = next(iter(data['query']['pages'].values()))\n",
    "    return page['extract'] if 'extract' in page else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8d8ade7f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "my_full_corpus = [get_wikipedia_page(\"Hayao_Miyazaki\")]\n",
    "my_full_corpus += [get_wikipedia_page(\"Studio_Ghibli\")]\n",
    "my_full_corpus += [get_wikipedia_page(\"Toei_Animation\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31778aed",
   "metadata": {},
   "source": [
    "We're also adding some Toei Animation content -- it helps to have things in our corpus that aren't directly relevant to our queries but are likely to cover similar topics, so we can get some more adjacent negative examples.\n",
    "\n",
    "The documents are very long, so let's use a `CorpusProcessor` to split them into chunks of around 256 tokens:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a14fe476",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from ragatouille.data import CorpusProcessor, llama_index_sentence_splitter\n",
    "\n",
    "corpus_processor = CorpusProcessor(document_splitter_fn=llama_index_sentence_splitter)\n",
    "documents = corpus_processor.process_corpus(my_full_corpus, chunk_size=256)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a26eea4c",
   "metadata": {},
   "source": [
    "Now that we have a corpus of documents, let's generate fake pairs of queries and relevant passages. Obviously, you wouldn't want to do that in the real world, but that's the topic of the next tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "53b48c50",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "queries = [\"What manga did Hayao Miyazaki write?\",\n",
    "           \"which film made ghibli famous internationally\",\n",
    "           \"who directed Spirited Away?\",\n",
    "           \"when was Hikotei Jidai published?\",\n",
    "           \"where's studio ghibli based?\",\n",
    "           \"where is the ghibli museum?\"\n",
    "] * 3\n",
    "pairs = []\n",
    "\n",
    "for query in queries:\n",
    "    fake_relevant_docs = random.sample(documents, 10)\n",
    "    for doc in fake_relevant_docs:\n",
    "        pairs.append((query, doc))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46f29b15",
   "metadata": {},
   "source": [
    "Here, we have created pairs. Retrieval training data is commonly stored in many different formats: pairs of [query, positive], triplets of [query, passage, label], triplets of [query, positive, negative], or triplets of [query, list_of_positives, list_of_negatives]. No matter which format your data is in, you don't need to worry about it: RAGatouille will generate ColBERT-friendly triplets for you, and export them to disk for easy `dvc` or `wandb` data tracking.\n",
    "\n",
    "Speaking of, let's process the data so it's ready for training. `RAGTrainer` has a `prepare_training_data` function, which will perform all the necessary steps. One of the steps it performs is called **hard negative mining**: that's searching the full collection of documents (even those not linked to a query) to retrieve passages that are semantically close to a query, but aren't actually relevant. Using those to train retrieval models has repeatedly been shown to greatly improve their ability to find actually relevant documents, so it's a very important step! \n",
    "\n",
    "RAGatouille handles all of this for you. By default, it'll fetch 10 negative examples per query, but you can customise this with `num_new_negatives`. You can also choose not to mine hard negatives and just sample random examples instead, by setting `mine_hard_negatives` to `False`. This might lower performance, but it will run much quicker on large volumes of data. If you've already mined negatives yourself, you can set `num_new_negatives` to 0 to bypass this entirely."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1ddaf3fb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Hard Negative SimpleMiner dense embedding model BAAI/bge-small-en-v1.5...\n",
      "Building hard negative index for 93 documents...\n",
      "All documents embedded, now adding to index...\n",
      "save_index set to False, skipping saving hard negative index\n",
      "Hard negative index generated\n",
      "mining\n",
      "mining\n",
      "mining\n",
      "mining\n",
      "mining\n",
      "mining\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'./data/'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.prepare_training_data(raw_data=pairs, data_out_path=\"./data/\", all_documents=my_full_corpus, num_new_negatives=10, mine_hard_negatives=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9468e25",
   "metadata": {},
   "source": [
    "Our training data is now fully processed and saved to disk in `data_out_path`! We're now ready to begin training our model with the `train` function. `train` takes many arguments, but the default values are already fairly strong!\n",
    "\n",
    "Don't be surprised that you don't see an `epochs` parameter here: ColBERT will train until it either reaches `maxsteps` or has seen the entire training data once (a full epoch). This is by design!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "06773f75-c844-42a4-b786-e56ef899a96e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#> Starting...\n",
      "nranks = 1 \t num_gpus = 1 \t device=0\n",
      "{\n",
      "    \"query_token_id\": \"[unused0]\",\n",
      "    \"doc_token_id\": \"[unused1]\",\n",
      "    \"query_token\": \"[Q]\",\n",
      "    \"doc_token\": \"[D]\",\n",
      "    \"ncells\": null,\n",
      "    \"centroid_score_threshold\": null,\n",
      "    \"ndocs\": null,\n",
      "    \"load_index_with_mmap\": false,\n",
      "    \"index_path\": null,\n",
      "    \"nbits\": 4,\n",
      "    \"kmeans_niters\": 20,\n",
      "    \"resume\": false,\n",
      "    \"similarity\": \"cosine\",\n",
      "    \"bsize\": 32,\n",
      "    \"accumsteps\": 1,\n",
      "    \"lr\": 5e-6,\n",
      "    \"maxsteps\": 500000,\n",
      "    \"save_every\": 0,\n",
      "    \"warmup\": 0,\n",
      "    \"warmup_bert\": null,\n",
      "    \"relu\": false,\n",
      "    \"nway\": 2,\n",
      "    \"use_ib_negatives\": true,\n",
      "    \"reranker\": false,\n",
      "    \"distillation_alpha\": 1.0,\n",
      "    \"ignore_scores\": false,\n",
      "    \"model_name\": \"GhibliColBERT\",\n",
      "    \"query_maxlen\": 32,\n",
      "    \"attend_to_mask_tokens\": false,\n",
      "    \"interaction\": \"colbert\",\n",
      "    \"dim\": 128,\n",
      "    \"doc_maxlen\": 256,\n",
      "    \"mask_punctuation\": true,\n",
      "    \"checkpoint\": \"colbert-ir\\/colbertv2.0\",\n",
      "    \"triples\": \"data\\/triples.train.colbert.jsonl\",\n",
      "    \"collection\": \"data\\/corpus.train.colbert.tsv\",\n",
      "    \"queries\": \"data\\/queries.train.colbert.tsv\",\n",
      "    \"index_name\": null,\n",
      "    \"overwrite\": false,\n",
      "    \"root\": \".ragatouille\\/\",\n",
      "    \"experiment\": \"colbert\",\n",
      "    \"index_root\": null,\n",
      "    \"name\": \"2024-01\\/03\\/17.31.59\",\n",
      "    \"rank\": 0,\n",
      "    \"nranks\": 1,\n",
      "    \"amp\": true,\n",
      "    \"gpus\": 1\n",
      "}\n",
      "Using config.bsize = 32 (per process) and config.accumsteps = 1\n",
      "[Jan 03, 17:32:08] #> Loading the queries from data/queries.train.colbert.tsv ...\n",
      "[Jan 03, 17:32:08] #> Got 6 queries. All QIDs are unique.\n",
      "\n",
      "[Jan 03, 17:32:08] #> Loading collection...\n",
      "0M "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:429: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "/opt/conda/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:139: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
      "  warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "#> LR will use 0 warmup steps and linear decay over 500000 steps.\n",
      "\n",
      "#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==\n",
      "#> Input: . What manga did Hayao Miyazaki write?, \t\t True, \t\t None\n",
      "#> Output IDs: torch.Size([32]), tensor([  101,     1,  2054,  8952,  2106, 10974,  7113,  2771,  3148, 18637,\n",
      "         4339,  1029,   102,   103,   103,   103,   103,   103,   103,   103,\n",
      "          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,\n",
      "          103,   103], device='cuda:0')\n",
      "#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')\n",
      "\n",
      "\t\t\t\t 4.857693195343018 8.471616744995117\n",
      "#>>>    16.0 20.45 \t\t|\t\t -4.449999999999999\n",
      "[Jan 03, 17:32:12] 0 13.329309463500977\n",
      "\t\t\t\t 5.369233131408691 9.21822452545166\n",
      "#>>>    15.58 20.6 \t\t|\t\t -5.020000000000001\n",
      "[Jan 03, 17:32:12] 1 13.330567611694336\n",
      "\t\t\t\t 9.961652755737305 15.010833740234375\n",
      "#>>>    10.04 19.99 \t\t|\t\t -9.95\n",
      "[Jan 03, 17:32:13] 2 13.342209530578614\n",
      "\t\t\t\t 6.136704921722412 12.674407005310059\n",
      "#>>>    11.99 17.86 \t\t|\t\t -5.869999999999999\n",
      "[Jan 03, 17:32:14] 3 13.347678432498231\n",
      "\t\t\t\t 8.609644889831543 13.769636154174805\n",
      "#>>>    13.18 21.75 \t\t|\t\t -8.57\n",
      "[Jan 03, 17:32:15] 4 13.356710034156064\n",
      "[Jan 03, 17:32:15] #> Done with all triples!\n",
      "#> Saving a checkpoint to .ragatouille/colbert/none/2024-01/03/17.31.59/checkpoints/colbert ..\n",
      "#> Joined...\n"
     ]
    }
   ],
   "source": [
    "trainer.train(batch_size=32,\n",
    "              nbits=4, # How many bits will the trained model use when compressing indexes\n",
    "              maxsteps=500000, # Maximum steps hard stop\n",
    "              use_ib_negatives=True, # Use in-batch negatives to calculate loss\n",
    "              dim=128, # How many dimensions per embedding. 128 is the default and works well.\n",
    "              learning_rate=5e-6, # Learning rate, small values ([3e-6,3e-5] work best if the base model is BERT-like, 5e-6 is often the sweet spot)\n",
    "              doc_maxlen=256, # Maximum document length. Because of how ColBERT works, smaller chunks (128-256) work very well.\n",
    "              use_relu=False, # Disable ReLU -- doesn't improve performance\n",
    "              warmup_steps=\"auto\", # Defaults to 10%\n",
    "             )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44ab6a51-a4bd-4fab-96d7-dd8e93dac462",
   "metadata": {},
   "source": [
    "And you're now done training! Your model is saved at the path printed in the training output, with the final checkpoint always being in the `.../checkpoints/colbert` path, and intermediate checkpoints saved at `.../checkpoints/colbert-{N_STEPS}`.\n",
    "\n",
    "You can now use your model by pointing at its local path, or upload it to the Hugging Face Hub to share it with the world!"
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": ".m114",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/:m114"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
